{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Base Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from timesfm import TimesFm, freq_map, data_loader\n",
    "from adapter.utils import load_adapter_checkpoint\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "tfm = TimesFm(\n",
    "    context_len=512,\n",
    "    horizon_len=128,\n",
    "    input_patch_len=32,\n",
    "    output_patch_len=128,\n",
    "    num_layers=20,\n",
    "    model_dims=1280,\n",
    "    backend=\"cpu\",\n",
    ")\n",
    "tfm.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DICT = {\n",
    "    \"ettm2\": {\n",
    "        \"boundaries\": [34560, 46080, 57600],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTm2.csv\",\n",
    "        \"freq\": \"15min\",\n",
    "    },\n",
    "    \"ettm1\": {\n",
    "        \"boundaries\": [34560, 46080, 57600],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTm1.csv\",\n",
    "        \"freq\": \"15min\",\n",
    "    },\n",
    "    \"etth2\": {\n",
    "        \"boundaries\": [8640, 11520, 14400],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTh2.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"etth1\": {\n",
    "        \"boundaries\": [8640, 11520, 14400],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTh1.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"elec\": {\n",
    "        \"boundaries\": [18413, 21044, 26304],\n",
    "        \"data_path\": \"../datasets/electricity/electricity.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"traffic\": {\n",
    "        \"boundaries\": [12280, 14036, 17544],\n",
    "        \"data_path\": \"../datasets/traffic/traffic.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"weather\": {\n",
    "        \"boundaries\": [36887, 42157, 52696],\n",
    "        \"data_path\": \"../datasets/weather/weather.csv\",\n",
    "        \"freq\": \"10min\",\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Adapter Checkpoint\n",
    "\n",
    "Specify the adapter checkpoint path, rank and the modules used to train the adapters and whether dora was employed or not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_adapter_checkpoint(\n",
    "    model=tfm,\n",
    "    adapter_checkpoint_path=\"./checkpoints/run_20240716_163900_lyo4psz3\",\n",
    "    lora_rank=1,\n",
    "    lora_target_modules=\"all\",\n",
    "    use_dora=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ettm1\"\n",
    "data_path = DATA_DICT[dataset][\"data_path\"]\n",
    "freq = DATA_DICT[dataset][\"freq\"]\n",
    "int_freq = freq_map(freq)\n",
    "boundaries = DATA_DICT[dataset][\"boundaries\"]\n",
    "\n",
    "data_df = pd.read_csv(open(data_path, \"r\"))\n",
    "\n",
    "ts_cols = [col for col in data_df.columns if col != \"date\"]\n",
    "num_cov_cols = None\n",
    "cat_cov_cols = None\n",
    "\n",
    "context_len = 512\n",
    "pred_len = 96\n",
    "\n",
    "num_ts = len(ts_cols)\n",
    "batch_size = 16\n",
    "\n",
    "dtl = data_loader.TimeSeriesdata(\n",
    "    data_path=data_path,\n",
    "    datetime_col=\"date\",\n",
    "    num_cov_cols=num_cov_cols,\n",
    "    cat_cov_cols=cat_cov_cols,\n",
    "    ts_cols=np.array(ts_cols),\n",
    "    train_range=[0, boundaries[0]],\n",
    "    val_range=[boundaries[0], boundaries[1]],\n",
    "    test_range=[boundaries[1], boundaries[2]],\n",
    "    hist_len=context_len,\n",
    "    pred_len=pred_len,\n",
    "    batch_size=num_ts,\n",
    "    freq=\"15min\",\n",
    "    normalize=True,\n",
    "    epoch_len=None,\n",
    "    holiday=False,\n",
    "    permute=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batches = dtl.tf_dataset(mode=\"test\", shift=pred_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mae_losses = []\n",
    "for batch in tqdm(test_batches.as_numpy_iterator()):\n",
    "    past = batch[0]\n",
    "    actuals = batch[3]\n",
    "    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
    "    forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
    "    mae_losses.append(np.abs(forecasts - actuals).mean())\n",
    "\n",
    "print(f\"MAE: {np.mean(mae_losses)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tanmay_tfm_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
